{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d9b9eaa6-a12f-4cf8-a4c5-e8ac2c15d15b",
   "metadata": {
    "id": "d9b9eaa6-a12f-4cf8-a4c5-e8ac2c15d15b"
   },
   "source": [
    "# 🔍 Predicting Item Prices from Descriptions (Part 3)\n",
    "---\n",
    "- Data Curation & Preprocessing\n",
    "- Model Benchmarking – Traditional ML vs LLMs\n",
    "- ➡️ E5 Embeddings & RAG\n",
    "- Fine-Tuning GPT-4o Mini\n",
    "- Evaluating LLaMA 3.1 8B Quantized\n",
    "- Fine-Tuning LLaMA 3.1 with QLoRA\n",
    "- Evaluating Fine-Tuned LLaMA\n",
    "- Summary & Leaderboard\n",
    "\n",
    "---\n",
    "\n",
    "# 🧠 Part 3: E5 Embeddings & RAG\n",
    "\n",
    "- 🧑‍💻 Skill Level: Advanced\n",
    "- ⚙️ Hardware: ⚠️ GPU required for embeddings (400K items) - use Google Colab\n",
    "- 🛠️ Requirements: 🔑 HF Token, OpenAI API Key\n",
    "- Tasks:\n",
    "    - Preprocessed item descriptions\n",
    "    - Generated and stored embeddings in ChromaDB\n",
    "    - Trained XGBoost on embeddings, pushed to HF Hub, and ran predictions\n",
    "    - Predicted prices with GPT-4o Mini using RAG\n",
    "\n",
    "Is Word2Vec enough for XGBoost, or do contextual E5 embeddings perform better?\n",
    "\n",
    "Does retrieval improve price prediction for GPT-4o Mini?\n",
    "\n",
    "Let’s find out.\n",
    "\n",
    "⚠️ This notebook assumes basic familiarity with RAG and contextual embeddings.\n",
    "We use the same E5 embedding space for both XGBoost and GPT-4o Mini with RAG, enabling a fair comparison.\n",
    "Embeddings are stored and queried via ChromaDB — no LangChain is used for creation or retrieval.\n",
    "\n",
    "---\n",
    "📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e2af5e-03cc-46dc-8a8b-37cb102d0e92",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "d8e2af5e-03cc-46dc-8a8b-37cb102d0e92",
    "outputId": "905907cc-81c5-4a3b-e7c8-9e237e594a09"
   },
   "outputs": [],
   "source": [
    "# Install required packages in Google Colab (includes every third-party package imported below)\n",
    "%pip install -q tqdm huggingface_hub numpy sentence-transformers datasets chromadb xgboost openai joblib matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ce6a892-b357-4132-b9c0-a3142a0244c8",
   "metadata": {
    "id": "4ce6a892-b357-4132-b9c0-a3142a0244c8"
   },
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import math\n",
    "import chromadb\n",
    "import re\n",
    "import joblib\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import gc\n",
    "from huggingface_hub import login, HfApi\n",
    "import numpy as np\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from datasets import load_dataset\n",
    "from google.colab import userdata\n",
    "from xgboost import XGBRegressor\n",
    "from openai import OpenAI\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "yBH-mvV0QBiw",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yBH-mvV0QBiw",
    "outputId": "b4b6df10-dc05-4dbe-dd8b-55bae5a2b7af"
   },
   "outputs": [],
   "source": [
    "# Mount Google Drive to access persistent storage\n",
    "\n",
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3OUI1jQYyaeX",
   "metadata": {
    "id": "3OUI1jQYyaeX"
   },
   "outputs": [],
   "source": [
    "# Google Colab User Data\n",
    "# Ensure you have set the following in your Google Colab environment:\n",
    "openai_api_key = userdata.get(\"OPENAI_API_KEY\")\n",
    "hf_token = userdata.get('HF_TOKEN')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99f6f632",
   "metadata": {},
   "outputs": [],
   "source": [
    "openai = OpenAI(api_key=openai_api_key)\n",
    "login(hf_token, add_to_git_credential=True)\n",
    "\n",
    "# Configuration\n",
    "ROOT = \"/content/drive/MyDrive/deal_finder\"\n",
    "CHROMA_PATH = f\"{ROOT}/chroma\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "FF-HryRnDXm5",
   "metadata": {
    "id": "FF-HryRnDXm5"
   },
   "outputs": [],
   "source": [
    "# Helper class for evaluating model predictions\n",
    "\n",
    "GREEN = \"\\033[92m\"\n",
    "YELLOW = \"\\033[93m\"\n",
    "RED = \"\\033[91m\"\n",
    "RESET = \"\\033[0m\"\n",
    "COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}\n",
    "\n",
    "class Tester:\n",
    "\n",
    "    def __init__(self, predictor, data, title=None, size=250):\n",
    "        self.predictor = predictor\n",
    "        self.data = data\n",
    "        self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
    "        self.size = size\n",
    "        self.guesses = []\n",
    "        self.truths = []\n",
    "        self.errors = []\n",
    "        self.sles = []\n",
    "        self.colors = []\n",
    "\n",
    "    def color_for(self, error, truth):\n",
    "        if error<40 or error/truth < 0.2:\n",
    "            return \"green\"\n",
    "        elif error<80 or error/truth < 0.4:\n",
    "            return \"orange\"\n",
    "        else:\n",
    "            return \"red\"\n",
    "\n",
    "    def run_datapoint(self, i):\n",
    "        datapoint = self.data[i]\n",
    "        guess = self.predictor(datapoint)\n",
    "        truth = datapoint[\"price\"]\n",
    "        error = abs(guess - truth)\n",
    "        log_error = math.log(truth+1) - math.log(guess+1)\n",
    "        sle = log_error ** 2\n",
    "        color = self.color_for(error, truth)\n",
    "        # title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
    "        self.guesses.append(guess)\n",
    "        self.truths.append(truth)\n",
    "        self.errors.append(error)\n",
    "        self.sles.append(sle)\n",
    "        self.colors.append(color)\n",
    "        # print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
    "\n",
    "    def chart(self, title):\n",
    "        # max_error = max(self.errors)\n",
    "        plt.figure(figsize=(12, 8))\n",
    "        max_val = max(max(self.truths), max(self.guesses))\n",
    "        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
    "        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
    "        plt.xlabel('Ground Truth')\n",
    "        plt.ylabel('Model Estimate')\n",
    "        plt.xlim(0, max_val)\n",
    "        plt.ylim(0, max_val)\n",
    "        plt.title(title)\n",
    "\n",
    "        # Add color legend\n",
    "        from matplotlib.lines import Line2D\n",
    "        legend_elements = [\n",
    "            Line2D([0], [0], marker='o', color='w', label='Accurate (green)', markerfacecolor='green', markersize=8),\n",
    "            Line2D([0], [0], marker='o', color='w', label='Medium error (orange)', markerfacecolor='orange', markersize=8),\n",
    "            Line2D([0], [0], marker='o', color='w', label='High error (red)', markerfacecolor='red', markersize=8)\n",
    "        ]\n",
    "        plt.legend(handles=legend_elements, loc='upper right')\n",
    "\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "    def report(self):\n",
    "        average_error = sum(self.errors) / self.size\n",
    "        rmsle = math.sqrt(sum(self.sles) / self.size)\n",
    "        hits = sum(1 for color in self.colors if color==\"green\")\n",
    "        title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
    "        self.chart(title)\n",
    "\n",
    "    def run(self):\n",
    "        self.error = 0\n",
    "        for i in range(self.size):\n",
    "            self.run_datapoint(i)\n",
    "        self.report()\n",
    "\n",
    "    @classmethod\n",
    "    def test(cls, function, data):\n",
    "        cls(function, data).run()\n"
   ]
  },
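  {
   "cell_type": "markdown",
   "id": "sketch-tester-metrics",
   "metadata": {},
   "source": [
    "The `Tester` above scores each prediction with three quantities: the absolute error, the squared log error (averaged and square-rooted into RMSLE), and a color bucket. The sketch below reproduces that arithmetic on a few made-up (guess, truth) pairs. The pairs are purely illustrative and not taken from the dataset, and `bucket` and `rmsle` are hypothetical helper names.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def bucket(guess, truth):\n",
    "    # Same thresholds as Tester.color_for: absolute OR relative error\n",
    "    error = abs(guess - truth)\n",
    "    if error < 40 or error / truth < 0.2:\n",
    "        return 'green'\n",
    "    if error < 80 or error / truth < 0.4:\n",
    "        return 'orange'\n",
    "    return 'red'\n",
    "\n",
    "def rmsle(pairs):\n",
    "    # Root mean squared log error over (guess, truth) pairs\n",
    "    sles = [(math.log(t + 1) - math.log(g + 1)) ** 2 for g, t in pairs]\n",
    "    return math.sqrt(sum(sles) / len(sles))\n",
    "\n",
    "# Made-up examples, not real dataset rows\n",
    "pairs = [(110.0, 100.0), (450.0, 300.0), (900.0, 1000.0)]\n",
    "colors = [bucket(g, t) for g, t in pairs]\n",
    "print(colors, round(rmsle(pairs), 3))\n",
    "```\n",
    "\n",
    "Note that the third pair is off by $100 yet still counts as a hit, because the relative error is only 10%. The `Hits` percentage in the chart title is the share of green points."
   ]
  },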
  {
   "cell_type": "markdown",
   "id": "6f82b230-2e03-4b1e-9be5-926fcd19acbe",
   "metadata": {
    "id": "6f82b230-2e03-4b1e-9be5-926fcd19acbe"
   },
   "source": [
    "## 📥 Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae00568",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you hit \"NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported\", run:\n",
    "# %pip install -U datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55f1495b-f343-4152-8739-3a99f5ac405d",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 177,
     "referenced_widgets": [
      "6e7c01d666f64fa58d6a059cc8d8f323",
      "597b7155767441e6a0283a19edced00f",
      "cf1360550eaa49a0867f55db8b8c4c77",
      "94f26137cccf47f6a36d9325bc8f5b9c",
      "a764b97f3dcd480c8860dde979e5e114",
      "f1ec9a46c9ce4e038f3051bbd1b2c661",
      "992f46ae91554731987b4baf79ba1bbd",
      "b4abe22402fe40fd82b7fe93b4bc06f3",
      "57ec058518734e3dbd27324cbba243c0",
      "f101230e8a9a431d85ee2f8e51add7ad",
      "e196658b093746588113240a60336437",
      "cb06a4d26cb84c708857b683d1e84c12",
      "e82ad07ba22e465cbe0232c504c3b693",
      "c4e0ed1165f54393aaec24cd4624d562",
      "295a3c6662034aaaab4d2e0192d1d1ce",
      "c38aff0c91a849feb547e78156c2c347",
      "69647c5595874c3185cebf6813ee908c",
      "1036b1af4b154916a3d4f16f5ed799eb",
      "e6347ff832cc4c04aef86594ea5a9e64",
      "01c63224aa6a4f0c9c88a4d85527e767",
      "1db34b9a4f1f42a897345b5a6630ced6",
      "9293f2d745024d7facb68e04cc188850",
      "26f6ec91efaf42909cec172fafe55987",
      "c1131f0324b0498da9bc59720e867eb6",
      "3e58017527a04634a489a33ed53fd312",
      "06cd89f57d08466c875d179e79e3ecd2",
      "2e0aa0aa87a04419a277f303f577f7ff",
      "8fa0fe1992db42a997e7cd3ee08bd09e",
      "accb1d5142a9498da0117f746fedd691",
      "fcc2fc2f82e2441995b9e61b23b9b91e",
      "da93fe316dd24cb48538b52ef2eaf6b5",
      "5cea58775faf41829c04d2a84e3e2c31",
      "1914ec7959d143d09a55da324bbcd47b",
      "a3d3504148df46f59b6770fb377e2bb6",
      "b088b9a503e24f179741d40d21a730d9",
      "b77dcf4632954d0c9c3b6d441c5f684d",
      "4cc8b3c4d9934f24a94b4601ab7816b5",
      "c093f1c0806a43b79594ddac856a301c",
      "9f4d9ac1aa074ed6b0248a4b18fde7db",
      "c00785b8fdda409e9cb435abbb0466da",
      "612e211af4cd46eb9d2f3148d1c7cb0b",
      "86f93c663cc446adbc6366a528cb01b0",
      "dd42911451ec48e086c1c99e76492321",
      "5b942241f11c4f2ab086f0f289f99a03",
      "d28a5c6172f74c0f8bbd2d949455f22e",
      "0e67b2055f214eb691b4b54d9431bdd8",
      "f81c4dc72b3b4b40a6a70528db732482",
      "043a355b6a85471ba0142eb25e2c9eb0",
      "8682bfab79a8409499797a3307e4d64d",
      "55a837644bb643ac864fa1a674e665c8",
      "33aae5a98bf5433b813ff8216e015089",
      "56eedfc5ba6642dc8443ab60f5f09b8c",
      "a1b710c227a84ea1a55c310084f13a93",
      "0d4bc0d0e88a4c77a202f9c11b2ee2a9",
      "20858379c2cd45d59070b18149d6e925"
     ]
    },
    "id": "55f1495b-f343-4152-8739-3a99f5ac405d",
    "outputId": "37317fe6-b560-4ad0-c7d6-66517fd67c42"
   },
   "outputs": [],
   "source": [
    "HF_USER = \"lisekarimi\"\n",
    "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
    "\n",
    "dataset = load_dataset(DATASET_NAME)\n",
    "train = dataset['train']\n",
    "test = dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85880d79-f1ba-4ee8-a039-b6acea84562c",
   "metadata": {
    "id": "85880d79-f1ba-4ee8-a039-b6acea84562c"
   },
   "outputs": [],
   "source": [
    "print(train[0][\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88842541-d73b-4fae-a550-6dedf8fab633",
   "metadata": {
    "id": "88842541-d73b-4fae-a550-6dedf8fab633"
   },
   "outputs": [],
   "source": [
    "print(train[0][\"price\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b8a9a5b-f74d-487d-a400-d157fea8c979",
   "metadata": {
    "id": "7b8a9a5b-f74d-487d-a400-d157fea8c979"
   },
   "source": [
    "## 📦 Embed + Save Training Data to Chroma\n",
    "- No LangChain used.\n",
    "- We use `intfloat/e5-small-v2` for embeddings:\n",
    "    - Fast, high-quality, retrieval-tuned\n",
    "    - **Expects a role prefix on every input; we prepend `passage: ` to each description**\n",
    "- We embed item descriptions and store them in ChromaDB, with price saved as metadata."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b95a87a8-2136-4e03-a36c-42e5d53a3e28",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 337,
     "referenced_widgets": [
      "8216f5d45e9345e493a43b8cbbe6598a",
      "ec3854658f8448fc8463e8635889f700",
      "7a90822b2aff4d5cb926442f01a77a9b",
      "9518c3af589744cfbbb51f87d68f216e",
      "327044765c044384a14be4e660bb152f",
      "0b773d68d2394d80a2baf73c1808752a",
      "21568b9954c8411d863baa7385df624f",
      "0a08828a0ba4430ea6e039949f220b5b",
      "3d5a51cfb5f44eecbf80d46e2e4608fd",
      "313f059a82104a9394182f6dcdb0bfb4",
      "6a625748afc84fe89a8af7a4ef638675",
      "ebe43cd30e414f31ab52614c6e9f9f2b",
      "88c29992adaa44af857e3216f7e53e60",
      "0528af78cef844e8a2b489dcb8fce049",
      "8cbccd78a79447158f02caadfa7d805f",
      "076ce072490c493ba5b3c431f6166eda",
      "dd7780038f8a4cd3837972c78b6583bc",
      "9e285e2b58934552b98edd998b82a678",
      "338efda3245a4989a9b3ee0795949bb8",
      "136dfb68394742ea98d9eb845730846c",
      "891d821725b6457c9d06737bf75fe3ed",
      "14feb4e20339465d966a6a80504eb819",
      "c02b637785324b9eb88e6a2c00cb986b",
      "3635da14e6f04e8f90548eb6381290a8",
      "1314757f404e47f5b0f6fa4de8537863",
      "9e5f2478e931476d882e471c7f66aaeb",
      "4ad885d69d9f492c960ca53426189707",
      "992d5e88d7844a52a283c0e19475ab78",
      "43eaec936c774e3380ae4ff1a823f3dc",
      "ceeb11b317ac4d37b59641024f77265f",
      "5e0371de53164830b4e8c2b6954b5947",
      "63a729492e8a4a759d75b769cbb3e1e7",
      "14dde2c87b7b4c9ea16d48732108dcd7",
      "f50717b099d142be95390ae8f1e99e6a",
      "ffa64c304dab4ef18e9ef50ac1625cd6",
      "f358351612004f64adffb931c3130603",
      "7593358526ae4a87bf4be0eb1bcfc076",
      "51536b45f5674d498272dc7b2def635d",
      "8fbe2a3fc07943e7bf0fdc927bab795a",
      "6b265cc65d5a42638572c1776faafdb1",
      "39fa86a7760d43c793eb8ef27475af7d",
      "eee5113e2dd1402faf76d00f07d8e0af",
      "6792ed7123724b2d8091bc8d36255e68",
      "e35094b24c154340bb1b3ebba7ac0a0d",
      "dd63bb6ffed34b6687a0c79d8af93fb7",
      "32080bc9381c449ab63794655ec6d714",
      "eb7aa289fefc465d98edeed9ce2bff51",
      "53fae218b4b74863af5fe53a66a5f7ef",
      "35bc6d95c60f4c3d8ddc6b3b0845ff7e",
      "f4765ca278ad4da4b465bd2920a21320",
      "7ac6ead5baef4f30aff170a30a9a7977",
      "e7adb5eb38d54b29b734d207982411c8",
      "8f4f51b75af74daa9b9ad6696760109c",
      "ae4db932b7544c6cb9ff668fa954addd",
      "be63f07eedbd4d46ac4913df45216108",
      "2e47d9e7b36a4ec69a9071930671ae8e",
      "7b1c7f9bf0e8412abb66bcfc24cf9668",
      "5c8742d3f663470e9977d006e83314b7",
      "74ec67e07ee0477eb41e21093ae82858",
      "4b60a8f023bc4d759bc197b11bf4e160",
      "7a090f162fa84568a5e486ba935c3ed1",
      "8b650428a6834f5d8ebe62ad327493e0",
      "5c4d22bce82546d28a8b0c041895c8e3",
      "16121b830a2948afb3ca8eb54e27a678",
      "0305a4b4408f4562b87b58098148326d",
      "68f07b5b7ad447ce9a87023d872c2e73",
      "2156a5ced089414c99a1bb8dd3a0b3b7",
      "2e6cd134c70e455a85c47b1575135883",
      "f4264985b5cc4a0f970a088fb90b8bcf",
      "71d790bf25324e6dbb5372f636c53da9",
      "dac3ba29ee4d4083a9abca7eab632534",
      "5c75c020a1914da680340fe826f3f58d",
      "195e6dfb82c84f0191838acbbfe38126",
      "b06adcaf8d4c497897ed3625f3afb4eb",
      "d4ab3971183a4e8fa10402e3542e6466",
      "444ca1f5213241c2bc71fa9ebe9ac3ca",
      "34d571f76ef845f4bc272a5e05491c31",
      "e8ee76b022d64b2cb24a2cb7b61aeef7",
      "8c9ac87788b04ae6899f3b62fdc3ed0d",
      "431b638c435444c38e50a09573b8f31b",
      "0430f22e24d14171b83261faa090f349",
      "0fa5ae935a554461b086a4b81470b9ad",
      "f072e665d27e442ab4d0e2eb33c98db9",
      "fd3b1885c39c4b70b083d7fddf74d4b6",
      "f77051cb151645559223ecf835426688",
      "0e17661f878948598703ee7942e5e1a2",
      "fca913c6cfff48099d1744d5b091fc46",
      "085baf51ecef46318ceafbaba2bb4490",
      "52309039c2d8421bbb8e99f63f5ba91f",
      "f4233cd960ea4f549734a5b1e1da5e2e",
      "42ce1b7765f547cd9ecd8b428ec1c718",
      "e72a08514d3b42d2b5fbf87a920bcdf0",
      "ad05cf4c0ed44341aa3cd2cbd22b513d",
      "db9915d53d784b85accebe1552c4e7e1",
      "9519b6d9bf1b45e3b56da4c28d2aeb2e",
      "cfeb0597708b49fa9b65342e1ac446ae",
      "e29617eff6fd4199a74b670198ba2a69",
      "1cea197a15d94654a0e792318435d707",
      "89dcb96670a8433593e3452fad3c9210",
      "0802085388be453b8fe5edee7e0a01ef",
      "1ed257f19b8b44ee85f09e10178ae52f",
      "04107981561149cba5baf74ccba87aa6",
      "09afb010020e4b2f91d7cdbdca316962",
      "b11b51beaa54474cb7682110bd2d24ae",
      "47822470ddf842cd9e3368090549a2b5",
      "835bce5d87a2417c9b6a5b27627447dc",
      "5ca06dd536d44de784984a492d23573f",
      "8e75bdb4469e497c8f021ebde7c6c9b3",
      "7f4d4f8ece1d4651a2186f10a0cc25a5",
      "92036442af5f4b698f2a54ecba4650e2"
     ]
    },
    "id": "b95a87a8-2136-4e03-a36c-42e5d53a3e28",
    "outputId": "6094328e-8c33-4b40-80e9-08c5cfb3e277"
   },
   "outputs": [],
   "source": [
    "# Load embedding model\n",
    "model_embedding = SentenceTransformer(\"intfloat/e5-small-v2\", device='cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "733cf41d-e81e-4cfc-b597-67da02dbc3cf",
   "metadata": {
    "id": "733cf41d-e81e-4cfc-b597-67da02dbc3cf"
   },
   "outputs": [],
   "source": [
    "# Init Chroma\n",
    "client = chromadb.PersistentClient(path=CHROMA_PATH)\n",
    "collection = client.get_or_create_collection(name=\"price_items\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f493c7d-1c72-40f9-a5c6-63c7f6b1cf2c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 91
    },
    "id": "1f493c7d-1c72-40f9-a5c6-63c7f6b1cf2c",
    "outputId": "72627732-4eee-4d9a-c8cb-0c42e2541a80"
   },
   "outputs": [],
   "source": [
    "# Format description function (no price in text)\n",
    "def description(item):\n",
    "    text = item[\"text\"].replace(\"How much does this cost to the nearest dollar?\\n\\n\", \"\")\n",
    "    text = text.split(\"\\n\\nPrice is $\")[0]\n",
    "    return f\"passage: {text}\"\n",
    "\n",
    "description(train[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f44bf613-adf6-4993-bf7b-6aa9fad21a03",
   "metadata": {
    "id": "f44bf613-adf6-4993-bf7b-6aa9fad21a03"
   },
   "outputs": [],
   "source": [
    "batch_size = 300    # how many items to insert into Chroma at once\n",
    "encode_batch_size = 1024  # how many items to encode at once in GPU memory\n",
    "\n",
    "for i in tqdm(range(0, len(train), batch_size), desc=\"Processing batches\"):\n",
    "\n",
    "    end_idx = min(i + batch_size, len(train))\n",
    "\n",
    "    # Collect documents and metadata\n",
    "    documents = [description(train[j]) for j in range(i, end_idx)]\n",
    "    metadatas = [{\"price\": train[j][\"price\"]} for j in range(i, end_idx)]\n",
    "    ids = [f\"doc_{j}\" for j in range(i, end_idx)]\n",
    "\n",
    "    # GPU batch encoding\n",
    "    vectors = model_embedding.encode(\n",
    "        documents,\n",
    "        batch_size=encode_batch_size,\n",
    "        show_progress_bar=False,\n",
    "        normalize_embeddings=True\n",
    "    ).tolist()\n",
    "\n",
    "    # Insert into Chroma\n",
    "    collection.add(\n",
    "        ids=ids,\n",
    "        documents=documents,\n",
    "        embeddings=vectors,\n",
    "        metadatas=metadatas\n",
    "    )\n",
    "\n",
    "print(\"✅ Embedding and storage to ChromaDB completed.\")"
   ]
  },
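  {
   "cell_type": "markdown",
   "id": "sketch-unit-vector-distance",
   "metadata": {},
   "source": [
    "The loop above passes `normalize_embeddings=True`, so every stored vector has unit length. For unit vectors, squared Euclidean distance and cosine similarity rank neighbours identically, because ||a - b||^2 = 2 - 2 cos(a, b). The sketch below checks this identity on two made-up 3-dimensional vectors (real E5-small vectors have 384 dimensions). It assumes the collection keeps Chroma's default distance setting, which this notebook does not override.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def normalize(v):\n",
    "    # Scale a vector to unit length\n",
    "    norm = math.sqrt(sum(x * x for x in v))\n",
    "    return [x / norm for x in v]\n",
    "\n",
    "def dot(a, b):\n",
    "    return sum(x * y for x, y in zip(a, b))\n",
    "\n",
    "def squared_l2(a, b):\n",
    "    return sum((x - y) ** 2 for x, y in zip(a, b))\n",
    "\n",
    "# Made-up vectors standing in for two item embeddings\n",
    "a = normalize([1.0, 2.0, 2.0])\n",
    "b = normalize([2.0, 1.0, 2.0])\n",
    "\n",
    "cosine = dot(a, b)  # for unit vectors, the dot product is the cosine\n",
    "print(round(squared_l2(a, b), 6), round(2 - 2 * cosine, 6))\n",
    "```\n",
    "\n",
    "The two printed values agree, which is why query vectors should be normalized the same way as the stored ones."
   ]
  },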
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2e2ccc9-b772-45f7-8258-cbc4f9c3ed59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now flush and clean\n",
    "print(\"🧹 Cleaning up and saving ChromaDB...\")\n",
    "client = None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c35d2fab-583f-4527-a7cc-9d31214b2f35",
   "metadata": {},
   "source": [
    "Our ChromaDB is currently saved in a persistent Google Drive path; for a production-ready app, we recommend uploading it to AWS S3 for better reliability and scalability.\n",
    "\n",
    "🧩 Now that we've generated the E5 embeddings, let's use them for both **XGBoost regression** and **GPT-4o Mini with RAG**."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40e4c587-211d-4bc0-91cf-6267f45405d6",
   "metadata": {
    "id": "40e4c587-211d-4bc0-91cf-6267f45405d6"
   },
   "source": [
    "## 📈 Embedding-Based Regression with XGBoost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f058ccac-3392-457d-b54c-6471960e9af3",
   "metadata": {
    "id": "f058ccac-3392-457d-b54c-6471960e9af3"
   },
   "outputs": [],
   "source": [
    "# Step 1: Load vectors and prices from Chroma\n",
    "result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
    "vectors = np.array(result['embeddings'])\n",
    "documents = result['documents']\n",
    "prices = [meta['price'] for meta in result['metadatas']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "JYQo0RaMb8Ql",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 254
    },
    "id": "JYQo0RaMb8Ql",
    "outputId": "c1641347-1fd4-41bb-e060-147224fc6bed"
   },
   "outputs": [],
   "source": [
    "# Step 2: Train XGBoost model\n",
    "xgb_model = XGBRegressor(n_estimators=100, random_state=42, n_jobs=-1, verbosity=0)\n",
    "xgb_model.fit(vectors, prices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "yaqG0z7jb919",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yaqG0z7jb919",
    "outputId": "6a2f9120-97e0-4436-aa12-40d94fbc5c64"
   },
   "outputs": [],
   "source": [
    "# Step 3: Serialize XGBoost model locally for Hugging Face upload\n",
    "MODEL_DIR = os.path.join(ROOT, \"models\")\n",
    "MODEL_FILENAME = \"xgboost_model.pkl\"\n",
    "LOCAL_MODEL = os.path.join(MODEL_DIR, MODEL_FILENAME)\n",
    "\n",
    "os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "joblib.dump(xgb_model, LOCAL_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Z_17sQUdxIr3",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 104,
     "referenced_widgets": [
      "2362f3121e5546b98e4623eb3680e96b",
      "ef53ee3b68c840d6a3fe98386d26bbd9",
      "a4768d0ecdd640a2a5bccd07a93c54b7",
      "e177440016974bc699b666fa721c6490",
      "2a9d0e5829174b738b4dfea1c71a3481",
      "ee6dffc7b79e405d923940166ef10590",
      "57bf3388622241869a5e9dab558dca72",
      "aa87f4feddd6409fbfb81f417e5d6662",
      "973a83ca118e4ed1b5a51821034ecc31",
      "d5a3c955aba14b3ea8e9b5c90a3bf20a",
      "daaa4f26bad545a394685e266f85a6ae"
     ]
    },
    "id": "Z_17sQUdxIr3",
    "outputId": "68ebdbdb-d42e-4bc8-addc-85b42d418d1d"
   },
   "outputs": [],
   "source": [
    "# Step 4: Push serialized XGBoost model to Hugging Face Hub\n",
    "api = HfApi(token=hf_token)\n",
    "REPO_NAME = \"smart-deal-finder-models\"\n",
    "REPO_ID = f\"{HF_USER}/{REPO_NAME}\"\n",
    "\n",
    "# Create the model repo if it doesn't exist\n",
    "api.create_repo(repo_id=REPO_ID, repo_type=\"model\", private=True, exist_ok=True)\n",
    "\n",
    "# Upload the saved model\n",
    "api.upload_file(\n",
    "    path_or_fileobj=LOCAL_MODEL,\n",
    "    path_in_repo=MODEL_FILENAME,\n",
    "    repo_id=REPO_ID,\n",
    "    repo_type=\"model\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f59125d-9fa6-483b-957f-4423a9b2c900",
   "metadata": {
    "id": "3f59125d-9fa6-483b-957f-4423a9b2c900"
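   },
   "outputs": [],
   "source": [
    "# Step 5: Define the predictor\n",
    "def xgb_predictor(datapoint):\n",
    "    doc = description(datapoint)\n",
    "    # encode() on a one-item list returns a (1, dim) array, the 2-D shape predict() expects\n",
    "    vector = model_embedding.encode([doc], normalize_embeddings=True)\n",
    "    prediction = xgb_model.predict(vector)[0]\n",
    "    return max(0.0, float(prediction))  # clamp negative predictions to zero"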
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a890f1f0-d827-472f-a7a9-6c2cbe3d8341",
   "metadata": {
    "id": "a890f1f0-d827-472f-a7a9-6c2cbe3d8341"
   },
   "source": [
    "🔔 Reminder: In Part 2, XGBoost with Word2Vec (non-contextual embeddings) achieved:\n",
    "- Avg. Error: ~$107\n",
    "- RMSLE: 0.83\n",
    "- Accuracy: 29.20%\n",
    "\n",
    "🧪 Now, let’s see if contextual embeddings improve XGBoost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "q-tIbVilTPxP",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 718
    },
    "id": "q-tIbVilTPxP",
    "outputId": "7c9043ef-a2c4-4933-b334-18d99690ba0f"
   },
   "outputs": [],
   "source": [
    "# Step 6: Run the Tester on a subset of test data\n",
    "tester = Tester(xgb_predictor, test)\n",
    "tester.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcb09db0-7d69-40e1-a6e3-b92263e38f1e",
   "metadata": {
    "id": "dcb09db0-7d69-40e1-a6e3-b92263e38f1e"
   },
   "source": [
    "Xgb Predictor Error=$110.68 RMSLE=0.93 Hits=30.4%"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ccd5d3f-98cd-45a8-951f-d6446062addc",
   "metadata": {
    "id": "1ccd5d3f-98cd-45a8-951f-d6446062addc"
   },
   "source": [
    "Results are nearly the same. In this setup, switching to contextual embeddings didn’t yield performance gains for XGBoost."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4db1051d-9a7e-4cec-87fc-0d77fd858ced",
   "metadata": {
    "id": "4db1051d-9a7e-4cec-87fc-0d77fd858ced"
   },
   "source": [
    "## 🚰 Retrieval-Augmented Pipeline – GPT-4o Mini\n",
    "\n",
    "- Preprocess: clean the input text (`description`)\n",
    "- Embed: generate the embedding vector (`get_embedding`)\n",
    "- Retrieve: find similar items in ChromaDB (`find_similars`)\n",
    "- Build Prompt: create the LLM prompt from the retrieved context and the masked target (`format_context`, `mask_price_value`, `build_messages`)\n",
    "- Predict: get the price estimate from the LLM and parse it (`gpt_4o_mini_rag`, `get_price`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "YPLxSn7eHp9N",
   "metadata": {
    "id": "YPLxSn7eHp9N"
   },
   "outputs": [],
   "source": [
    "test[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eFxFKNroNiyD",
   "metadata": {
    "id": "eFxFKNroNiyD"
   },
   "outputs": [],
   "source": [
    "# Step 1: Preprocess test item text\n",
    "# (uses the same `description(item)` function as during training)\n",
    "description(test[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lxIEtSWYHqCT",
   "metadata": {
    "id": "lxIEtSWYHqCT"
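   },
   "outputs": [],
   "source": [
    "# Step 2: Embed a test item\n",
    "def get_embedding(item):\n",
    "    # Normalize the query the same way as the vectors stored in Chroma\n",
    "    return model_embedding.encode([description(item)], normalize_embeddings=True)"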
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "y43prQsuHp_w",
   "metadata": {
    "id": "y43prQsuHp_w"
   },
   "outputs": [],
   "source": [
    "# Step 3: Query Chroma for similar items\n",
    "def find_similars(item):\n",
    "    results = collection.query(query_embeddings=get_embedding(item).astype(float).tolist(), n_results=5)\n",
    "    documents = results['documents'][0][:]\n",
    "    prices = [m['price'] for m in results['metadatas'][0][:]]\n",
    "    return documents, prices"
   ]
  },
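  {
   "cell_type": "markdown",
   "id": "sketch-bruteforce-retrieval",
   "metadata": {},
   "source": [
    "Conceptually, `collection.query` returns the stored items whose embeddings are closest to the query embedding, together with their metadata. The sketch below imitates that with an exhaustive search over a tiny made-up index. Chroma itself uses an approximate nearest-neighbour index rather than a full scan, and `top_k_similar`, the vectors, and the prices here are all hypothetical.\n",
    "\n",
    "```python\n",
    "def squared_l2(a, b):\n",
    "    return sum((x - y) ** 2 for x, y in zip(a, b))\n",
    "\n",
    "def top_k_similar(query, index, k=2):\n",
    "    # Exhaustive search: sort every stored item by distance to the query\n",
    "    ranked = sorted(index, key=lambda item: squared_l2(query, item['vector']))\n",
    "    documents = [item['document'] for item in ranked[:k]]\n",
    "    prices = [item['price'] for item in ranked[:k]]\n",
    "    return documents, prices\n",
    "\n",
    "# Made-up 2-D index; real entries would hold E5 vectors and item text\n",
    "index = [\n",
    "    {'document': 'passage: USB-C cable', 'vector': [0.9, 0.1], 'price': 9.99},\n",
    "    {'document': 'passage: HDMI cable', 'vector': [0.8, 0.3], 'price': 12.5},\n",
    "    {'document': 'passage: Office chair', 'vector': [0.1, 0.95], 'price': 149.0},\n",
    "]\n",
    "\n",
    "documents, prices = top_k_similar([1.0, 0.0], index)\n",
    "print(documents, prices)\n",
    "```\n",
    "\n",
    "The two cable entries are returned and the chair is left out. In the notebook, the retrieved documents and prices are what get formatted into the prompt as context."
   ]
  },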
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "nxAOUFRkHp6v",
   "metadata": {
    "id": "nxAOUFRkHp6v"
   },
   "outputs": [],
   "source": [
    "documents, prices = find_similars(test[1])\n",
    "documents, prices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "djPoSk6sHo84",
   "metadata": {
    "id": "djPoSk6sHo84"
   },
   "outputs": [],
   "source": [
    "# Step 4: Format similar items as context\n",
    "def format_context(similars, prices):\n",
    "    message = \"To provide some context, here are some other items that might be similar to the item you need to estimate.\\n\\n\"\n",
    "    for similar, price in zip(similars, prices):\n",
    "        message += f\"Potentially related product:\\n{similar}\\nPrice is ${price:.2f}\\n\\n\"\n",
    "    return message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "F3yxhnqSHp4C",
   "metadata": {
    "id": "F3yxhnqSHp4C"
   },
   "outputs": [],
   "source": [
    "print(format_context(documents, prices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pEJobsKNHqE8",
   "metadata": {
    "id": "pEJobsKNHqE8"
   },
   "outputs": [],
   "source": [
    "# Step 5: Mask the price in the test item\n",
    "def mask_price_value(text):\n",
    "    return re.sub(r\"(\\n\\nPrice is \\$).*\", r\"\\1\", text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "vLhBNVBNQAHS",
   "metadata": {
    "id": "vLhBNVBNQAHS"
   },
   "outputs": [],
   "source": [
    "# Step 6: Build LLM messages\n",
    "def build_messages(datapoint, similars, prices):\n",
    "\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation.\"\n",
    "\n",
    "    context = format_context(similars, prices)\n",
    "\n",
    "    prompt = mask_price_value(datapoint[\"text\"])\n",
    "    prompt = prompt.replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
    "\n",
    "    user_prompt = context + \"And now the question for you:\\n\\n\" + prompt\n",
    "\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt},\n",
    "        {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "I94fNHfBHp1a",
   "metadata": {
    "id": "I94fNHfBHp1a"
   },
   "outputs": [],
   "source": [
    "build_messages(test[1], documents, prices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5NfY_GAVHpy4",
   "metadata": {
    "id": "5NfY_GAVHpy4"
   },
   "outputs": [],
   "source": [
    "# Step 7: Run prediction\n",
    "def get_price(s):\n",
    "    s = s.replace('$','').replace(',','')\n",
    "    match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
    "    return float(match.group()) if match else 0\n",
    "\n",
    "def gpt_4o_mini_rag(item):\n",
    "    documents, prices = find_similars(item)\n",
    "    response = openai.chat.completions.create(\n",
    "        model=\"gpt-4o-mini\",\n",
    "        messages=build_messages(item, documents, prices),\n",
    "        seed=42,\n",
    "        max_tokens=5\n",
    "    )\n",
    "    reply = response.choices[0].message.content\n",
    "    return get_price(reply)"
   ]
  },
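  {
   "cell_type": "markdown",
   "id": "sketch-get-price-parsing",
   "metadata": {},
   "source": [
    "`get_price` has to cope with replies that are not a bare number. The sketch below restates the same parsing logic and runs it on a few made-up replies to show what the regular expression extracts. The replies are illustrative, not actual model output.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def get_price(s):\n",
    "    # Drop currency formatting, then take the first number in the reply\n",
    "    s = s.replace('$', '').replace(',', '')\n",
    "    match = re.search(r'[-+]?\\d*\\.\\d+|\\d+', s)\n",
    "    return float(match.group()) if match else 0\n",
    "\n",
    "# Made-up replies\n",
    "for reply in ['1,299.99', 'about 45 dollars', 'N/A']:\n",
    "    print(repr(reply), get_price(reply))\n",
    "```\n",
    "\n",
    "A reply with no digits falls back to 0, so the evaluation records an error equal to the true price instead of raising an exception."
   ]
  },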
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Pg-GJTT0HpwV",
   "metadata": {
    "id": "Pg-GJTT0HpwV"
   },
   "outputs": [],
   "source": [
    "print(test[1][\"price\"])\n",
    "print(gpt_4o_mini_rag(test[1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54103ab4-d6dd-4c0b-add5-5d9741e934b4",
   "metadata": {
    "id": "54103ab4-d6dd-4c0b-add5-5d9741e934b4"
   },
   "source": [
    "🔔 Reminder: In Part 2, GPT-4o Mini (without RAG) achieved:\n",
    "- Avg. Error: ~$99\n",
    "- RMSLE: 0.75\n",
    "- Accuracy: 44.8%\n",
    "\n",
    "🧪 Let’s find out if RAG can boost GPT-4o Mini’s price prediction capabilities.\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "r0NGJupwHppF",
   "metadata": {
    "id": "r0NGJupwHppF"
   },
   "outputs": [],
   "source": [
    "Tester.test(gpt_4o_mini_rag, test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00545880-d9e1-4934-8008-b62c105d177b",
   "metadata": {
    "id": "00545880-d9e1-4934-8008-b62c105d177b"
   },
   "source": [
    "Gpt 4O Mini Rag Error=$59.54 RMSLE=0.42 Hits=69.2%"
   ]
  },
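  {
   "cell_type": "markdown",
   "id": "sketch-relative-improvement",
   "metadata": {},
   "source": [
    "Relative to the Part 2 baseline quoted earlier (about $99 average error, RMSLE 0.75), the result above can also be expressed as a percentage reduction. This is a quick calculation on the rounded figures reported in this notebook, so the percentages are indicative only. `relative_drop` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "def relative_drop(before, after):\n",
    "    # Reduction as a percentage of the starting value\n",
    "    return (before - after) / before * 100\n",
    "\n",
    "error_drop = relative_drop(99.0, 59.54)\n",
    "rmsle_drop = relative_drop(0.75, 0.42)\n",
    "print(f'Error: -{error_drop:.1f}%  RMSLE: -{rmsle_drop:.1f}%')\n",
    "```"
   ]
  },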
  {
   "cell_type": "markdown",
   "id": "2b9f46ae-92b5-4189-89b0-df88a600bb89",
   "metadata": {
    "id": "2b9f46ae-92b5-4189-89b0-df88a600bb89"
   },
   "source": [
    "🎉 **GPT-4o Mini + RAG shows clear gains:**  \n",
    "Average error dropped from **$99 → $59.54**, RMSLE from **0.75 → 0.42**, and accuracy rose from **48.8% → 69.2%**.  \n",
    "\n",
    "Adding retrieval-based context led to a strong performance boost for GPT-4o Mini.\n",
    "\n",
    "Now the question is — can fine-tuning push it even further, surpass RAG, and challenge larger models?\n",
    "\n",
    "🔜 See you in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part4_ft_gpt4omini.ipynb)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
